Conversation
```python
)
if (
    hasattr(self, "_inference_token_dispatcher")
    and self.is_inference_cuda_graphed_iteration
```
[SUGGESTION] The `log_overload_factor` block inside `experts_compute_dispatch` (balanced count retrieval, dispatcher type check, tensor construction, hook registration) should be extracted into a private method such as `_record_overload_factor(self, dispatched_input, tokens_per_expert)`.

Mixing this logic into `experts_compute_dispatch` hurts readability. A single call site keeps the dispatch method focused on dispatch logic.
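A hedged sketch of what the extraction could look like. The helper name comes from the suggestion above; the tracker interface (`record_fwd`) and the stub classes are assumptions made only so the sketch runs standalone:

```python
import torch

class _TrackerStub:
    """Stand-in for the real overload tracker (assumed interface)."""

    def __init__(self):
        self.records = []

    def record_fwd(self, tokens_on_rank, balanced):
        self.records.append((float(tokens_on_rank), float(balanced)))

class _DispatchSketch:
    """Shows the single call site experts_compute_dispatch would keep."""

    def __init__(self, tracker):
        self.tracker = tracker

    def _record_overload_factor(self, dispatched_input, tokens_per_expert, balanced):
        # All overload-logging work (count retrieval, tensor construction,
        # record/hook registration) is grouped behind this one method, so
        # the dispatch path only contains a single call to it.
        tokens_on_rank = tokens_per_expert.detach().sum().float()
        self.tracker.record_fwd(tokens_on_rank, balanced)
        return dispatched_input

tracker = _TrackerStub()
layer = _DispatchSketch(tracker)
out = layer._record_overload_factor(
    torch.randn(6, 4), torch.tensor([4, 2]), balanced=3.0
)
print(tracker.records)  # [(6.0, 3.0)]
```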
```python
if rm is not None:
    return rm
cm = getattr(td, "_comm_manager", None)
if cm is not None:
```
[SUGGESTION] Avoid abbreviated variable names in parallel/distributed code where clarity is critical:

- `td` → `token_dispatcher`
- `rm` → `routing_map`
- `cm` → `comm_manager`

Same applies to `td` (L501) and `ws` (L495) in `experts_compute_dispatch`. Use `tp_ep_world_size` for `ws`.
```python
    Flex/HybridEP keep the map on ``_comm_manager``.
    """
    td = self.token_dispatcher
    rm = getattr(td, "routing_map", None)
```
[SUGGESTION] The balanced token count (`routing_map.shape[0] * topk`) could be computed earlier in `MoELayer.forward()` directly from `hidden_states` before any dispatch, rather than reading `routing_map` in this post-`token_dispatch` window.

The current approach has two fragilities:

- Requires `routing_map` to still be alive after `token_dispatch` but before `dispatch_postprocess` clears it, a timing assumption tied to dispatcher internals.
- Needs separate handling for AllGather (`routing_map` attr) vs Flex/HybridEP (`_comm_manager.routing_map`) dispatchers, coupling to internal implementation details.

Computing from `hidden_states.shape[0]` in `MoELayer.forward()` would remove `_routing_map_after_token_dispatch` entirely.
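Under this suggestion the balanced count is a pure shape computation. A minimal sketch, assuming the helper name and call site are hypothetical:

```python
import torch

def local_balanced_token_count(hidden_states: torch.Tensor, topk: int) -> float:
    # routing_map has shape [num_tokens, num_experts], so
    # routing_map.shape[0] * topk == hidden_states.shape[0] * topk.
    # The same scalar is therefore available before token_dispatch runs,
    # with no dependency on dispatcher internals or attribute lifetimes.
    return float(hidden_states.shape[0] * topk)

hidden_states = torch.randn(128, 64)  # 128 tokens on this rank
print(local_balanced_token_count(hidden_states, topk=2))  # 256.0
```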
```python
local_balanced = torch.empty(
    (), device=dispatched_input.device, dtype=torch.float32
)
local_balanced.fill_(base)
```
[SUGGESTION] `torch.empty(())` + `fill_()` is unnecessarily verbose. Use `torch.tensor` directly:

```python
local_balanced = torch.tensor(base, device=dispatched_input.device, dtype=torch.float32)
```

```python
tokens_on_rank = tokens_per_expert.detach().sum()
if not tokens_on_rank.is_floating_point():
    tokens_on_rank = tokens_on_rank.float()
tokens_on_rank = tokens_on_rank.to(device=tensor.device, dtype=torch.float32).reshape(())
```
[SUGGESTION] `.reshape(())` is a no-op here: `tokens_per_expert.detach().sum()` already returns a 0-dim tensor. The same applies to `balanced` on L1016. Both `.reshape(())` calls can be removed.
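Both this point and the `torch.tensor` simplification above are easy to check in isolation:

```python
import torch

# torch.tensor replaces the empty(()) + fill_() pair with a single call.
a = torch.empty((), dtype=torch.float32).fill_(3.0)
b = torch.tensor(3.0, dtype=torch.float32)
print(torch.equal(a, b))  # True

# .sum() already yields a 0-dim tensor, so a trailing reshape(()) is a no-op.
tokens_per_expert = torch.tensor([4, 0, 2])
s = tokens_per_expert.detach().sum()
print(s.dim(), s.shape)  # 0 torch.Size([])
```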
```python
    device=device,
    dtype=torch.float32,
)
torch.distributed.all_reduce(
```
[SUGGESTION] `report()` contains 6–7 sequential `all_reduce` calls across the same groups. Some can be fused to reduce collective launch overhead:

- tp_ep: `all_reduce(max_actual, MAX)` (L283) and `all_reduce(balanced_stacked, SUM)` (L295) can be packed into one call by stacking both tensors.
- dp: `all_reduce(overload_avg, AVG)` (L303) and `all_reduce(overload_max, MAX)` (L307) operate on the same data; consider a single fused reduce.
- `ratio_t`: two sequential all_reduces (tp_ep MAX L268, dp MAX L272) could be deferred and folded into the tp_ep and dp passes above.

At scale (hundreds of MoE layers, large DP), reducing the number of collectives per logging step is meaningful.
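A runnable sketch of the packing pattern, with the caveat that only reductions sharing the same op and the same group can share a buffer (AVG cannot be packed with MAX directly, though AVG can be rewritten as SUM plus a local divide). The single-process gloo group below is only setup scaffolding so the example runs standalone; in Megatron the process groups already exist:

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Stand-alone single-process gloo group (assumed setup for this sketch).
init_path = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group(
    "gloo", init_method=f"file://{init_path}", rank=0, world_size=1
)

# Two MAX reductions over the same group: concatenate into one flat buffer,
# reduce once, then slice. One collective launch instead of two.
max_actual = torch.tensor([150.0, 130.0])
ratio_t = torch.tensor([1.9])
packed = torch.cat([max_actual, ratio_t])
dist.all_reduce(packed, op=dist.ReduceOp.MAX)
max_actual, ratio_t = packed[:2], packed[2:]

dist.destroy_process_group()
print(max_actual.tolist())  # [150.0, 130.0]
```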
```python
if self._pending_clear:
    self._pending_clear = False
    self._clear_storage()
```
[SUGGESTION] The `_pending_clear` deferred-clear mechanism adds complexity that does not deliver its intended benefit.

The rationale (from the class docstring) is to keep tensor handles valid during CUDA graph replay windows. However, CUDA graph replay does not re-execute Python-side autograd functions: `record_fwd` and `record_bwd` are never called during replay. So the tracker never receives new data during replay regardless of whether storage has been cleared, which makes the deferred-clear protection moot.

For non-CUDA-graph training, the extra state (the `_pending_clear` flag plus the `_flush_pending_clear()` call in every `record_fwd`/`record_bwd`) is pure overhead with no benefit.

Suggestion: have `clear()` call `_clear_storage()` directly and remove `_pending_clear` and `_flush_pending_clear()`.
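A minimal sketch of the simplified lifecycle, assuming the method names from the review (the real tracker holds more state than this):

```python
class OverloadTrackerSketch:
    """Tracker with a direct clear(), no deferred-clear machinery."""

    def __init__(self):
        self._fwd = []
        self._fwd_balanced = []

    def record_fwd(self, tokens_on_rank, balanced):
        # No _flush_pending_clear() call on the hot recording path.
        self._fwd.append(tokens_on_rank)
        self._fwd_balanced.append(balanced)

    def clear(self):
        # Clears immediately; CUDA graph replay never re-enters these
        # Python hooks, so there is no handle-lifetime hazard to defer for.
        self._fwd.clear()
        self._fwd_balanced.clear()

t = OverloadTrackerSketch()
t.record_fwd(6.0, 4.0)
t.clear()
print(len(t._fwd), len(t._fwd_balanced))  # 0 0
```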
```python
g = parallel_state.get_pipeline_model_parallel_group(check_initialized=False)
return g


def report(
```
[SUGGESTION] `report()` is ~200 lines and mixes several independent responsibilities: tp_ep reduction, dp reduction, cumsum peak computation, pp reduction, TensorBoard/W&B logging, and log-string assembly. This makes it hard to test, read, and debug any single concern.

Suggest splitting into focused private methods:

- `_reduce_tp_ep(fwd_tensors, balanced_tensors)` → returns `tp_ep_overload`
- `_reduce_dp(tp_ep_overload)` → returns `(overload_avg, overload_max)`
- `_compute_max_cum_overload()` → returns `max_cum_overload_factor`
- `_reduce_pp(max_overload, max_cum)` → returns pp-reduced scalars
- `_log_to_writers(writer, wandb_writer, scalars, iteration)` → writes TB/W&B

`report()` then becomes a thin orchestrator responsible only for call ordering.
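With those helpers, the orchestrator shape could look like the sketch below. The method names come from the list above; the bodies are stubbed with illustrative numbers so the structure is runnable:

```python
class ReportSketch:
    """Thin orchestrator; each stage lives in its own testable method."""

    def _reduce_tp_ep(self):
        return [1.5, 1.2]                 # stub: tp_ep overload per entry

    def _reduce_dp(self, tp_ep_overload):
        avg = sum(tp_ep_overload) / len(tp_ep_overload)
        return avg, max(tp_ep_overload)   # stub: dp AVG / MAX reductions

    def _compute_max_cum_overload(self):
        return 1.8                        # stub: cumsum peak ratio

    def _reduce_pp(self, max_overload, max_cum):
        return max_overload, max_cum      # stub: pp all_reduce(MAX)

    def _log_to_writers(self, scalars):
        self.logged = scalars             # stub: TB/W&B writes

    def report(self):
        # report() is now only responsible for call ordering.
        tp_ep_overload = self._reduce_tp_ep()
        overload_avg, overload_max = self._reduce_dp(tp_ep_overload)
        max_cum = self._compute_max_cum_overload()
        overload_max, max_cum = self._reduce_pp(overload_max, max_cum)
        self._log_to_writers(
            {"avg": overload_avg, "max": overload_max, "max_cum": max_cum}
        )
        return self.logged

print(ReportSketch().report())  # {'avg': 1.35, 'max': 1.5, 'max_cum': 1.8}
```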
```python
"""Tracker for MoE overload factor metrics.

Records per-layer **tokens on this rank** after dispatch
(``tokens_per_expert.sum()``) and a pre-dispatch **balanced token count** scalar
(from ``routing_map.shape[0] * moe_router_topk`` read after ``token_dispatch``),
via an autograd hook on ``dispatched_input``. ``report()`` does
``all_reduce(MAX)`` on per-rank actual totals over ``tp_ep_group``, divides by
balanced count per rank (from summed local counts / size) to get **tp_ep
overload** per microbatch entry, then ``all_reduce(AVG)`` and
``all_reduce(MAX)`` on that overload across ``dp_group`` before scalar
summaries. Over the **pipeline-parallel** group, ``max`` and ``max_cum``
scalars are ``all_reduce(MAX)`` so every stage agrees on the worst overload;
ranks without MoE layers contribute ``0``. The **mean** overload scalar is
**not** reduced across PP (each rank logs its local mean, ``0`` if it recorded
nothing). ``_fwd_bwd`` / ``_fwd_bwd_balanced`` mirror interleaved fwd/bwd so
cumulative peaks of actual vs balanced token counts can be compared.

Lifecycle: set_process_groups() and record_fwd/record_bwd during forward
(SaveOverloadFactorFunction in MoELayer) → report() at step end
(sync, aggregate, log, deferred clear) → repeat.

``clear()`` only marks storage for reset on the next ``record_fwd`` or
``record_bwd`` so tensor handles stay valid until Python runs a recording
hook again (e.g. across CUDA graph replay windows that skip those hooks).
"""
```
Comments could be refactored with AI for better readability.

Example comment refactored by GPT:

```python
"""Track MoE overload-factor metrics.

Recorded values
---------------
- Per-layer actual tokens on this rank after dispatch:
  tokens_per_expert.sum().
- Per-layer balanced token count before dispatch:
  routing_map.shape[0] * moe_router_topk (read after token_dispatch).
- Both values are captured by an autograd hook on dispatched_input.
- _fwd_bwd and _fwd_bwd_balanced mirror interleaved fwd/bwd events so
  cumulative peaks of actual vs balanced token counts can be compared.

How report() aggregates
-----------------------
1. In tp_ep_group, run all_reduce(MAX) on per-rank actual totals.
2. Divide by balanced tokens per rank (summed local balanced counts / group
   size) to get per-entry tp_ep overload.
3. In dp_group, run all_reduce(AVG) and all_reduce(MAX) on overload
   before scalar summaries.
4. In the pipeline-parallel group, max and max_cum use all_reduce(MAX) so
   every stage agrees on the worst overload. Ranks without MoE layers
   contribute 0.
5. Mean overload is not reduced across PP; each rank logs its local mean
   (0 if nothing was recorded).

Lifecycle
---------
set_process_groups() and record_fwd()/record_bwd() are called during forward
(SaveOverloadFactorFunction in MoELayer). report() runs at step end
(sync, aggregate, log, deferred clear), then the cycle repeats.

clear() behavior
----------------
clear() does not immediately reset storage. It marks storage for reset on
the next record_fwd() or record_bwd() so tensor handles stay valid until
Python executes a recording hook again (for example across CUDA graph replay
windows that skip those hooks).
"""
```
```python
    return grad_output, None, None, None


def save_overload_factor_to_tracker(
```
[SUGGESTION] `save_overload_factor_to_tracker` and `SaveOverloadFactorFunction` are misleadingly named: neither computes nor saves an overload factor. What they actually do is record post-dispatch token counts (actual tokens on this rank + balanced token count) into the tracker. The overload factor itself is only computed later in `report()`.

Suggested renames:

- `SaveOverloadFactorFunction` → `RecordDispatchTokenCountsFunction`
- `save_overload_factor_to_tracker` → `record_dispatch_token_counts`
What does this PR do?

This PR introduces a utility to log the overload factor through `log_overload_factor`.

MoE overload factor
The overload factor is the ratio of the token count on the most loaded rank in a TP-EP group to the balanced token count per rank, i.e. the count each rank would see with perfectly balanced routing. It measures workload imbalance.
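As a worked example (numbers purely illustrative):

```python
# 4 ranks in a tp_ep group; 400 routed token copies in total, so each rank
# would hold 100 under perfectly balanced routing.
tokens_per_rank = [80, 100, 150, 70]
balanced_per_rank = sum(tokens_per_rank) / len(tokens_per_rank)  # 100.0
overload_factor = max(tokens_per_rank) / balanced_per_rank
print(overload_factor)  # 1.5 -> the busiest rank holds 1.5x its fair share
```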
- `moe/avg_overload_factor`: arithmetic mean of the overload factor over all layer × microbatch × DP-slice entries recorded on this rank in the step.
- `moe/max_overload_factor`: maximum of the overload factor over all layer × microbatch × DP-slice entries. Useful for estimating peak buffer size for intermediate activations in the forward path.
- `moe/max_cum_overload_factor`: reflects the cumulative fwd/bwd token counts; the ratio of the peak cumulative actual token count to the peak cumulative balanced count. Useful for estimating how much activation-related memory may need to be retained through backward.
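The cumulative peak behind `moe/max_cum_overload_factor` can be illustrated with a small running-sum example (numbers illustrative; positive entries stand for forward allocations, negative entries for backward releases):

```python
from itertools import accumulate

# Interleaved fwd (+) / bwd (-) token counts per layer on this rank,
# next to the counts a perfectly balanced router would have produced.
actual   = [+120, +150, -120, +130, -150, -130]
balanced = [+100, +100, -100, +100, -100, -100]

peak_actual = max(accumulate(actual))      # 280
peak_balanced = max(accumulate(balanced))  # 200
max_cum_overload_factor = peak_actual / peak_balanced
print(max_cum_overload_factor)  # 1.4
```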
Contribution process

Pre-checks

Code review

Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

`.github/CODEOWNERS`. Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change `megatron/core`, once all expert reviewers have approved, the `Final Review` label is applied automatically and final reviewers are assigned. For PRs outside `megatron/core`, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the `Approved` label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.